function dataout = ca_analysis_functions(datain, fflag)
    % CA_ANALYSIS_FUNCTIONS  Dispatch one step of the calcium-imaging analysis.
    %   DATAOUT = CA_ANALYSIS_FUNCTIONS(DATAIN, FFLAG) runs the step selected by the
    %   string FFLAG (one of the case labels below) on the inputs packed in the cell
    %   array DATAIN and returns that step's results packed in the cell array DATAOUT.
    %   Several steps delegate to helpers defined outside this function
    %   (extract_factor, constrained_foopsi, ar_integrate).
    %%% global parameters %%%
    % Fsi: sampling rate of the imaging traces; durations are multiplied by it to
    %      obtain frame counts (presumably frames per second -- TODO confirm).
    Fsi = 20;
    % Fsb: sampling rate of behaviour/event indices; event indices are converted to
    %      imaging frames as idx * Fsi / Fsb.
    Fsb = 20;
    
    switch fflag
        case 'rescale'
            % Apply normalize (default z-score) to every row of every cell of datain{1},
            % each row independently. Returns {rescaled cell array}, same size.
            data = datain{1};
            for c = 1: numel(data)
                traces = data{c};
                for r = 1: size(traces, 1)
                    traces(r, :) = normalize(traces(r, :));
                end
                data{c} = traces;
            end
            
            dataout = {data};
            
        case 'epoch_modify'
            % For every cell, take element datain{2} of its inner cell array and replace
            % every value below 1e-5, and then every NaN, with a random value in
            % (0, 1e-5). Returns {datao}, where datao{i,j,k} = {cleaned block}.
            data = datain{1};
            flag = datain{2};
            [nc, nw, nm] = size(data);
            floorval = 1e-5;
            
            datao = data;
            % The i -> j -> k -> row -> page order is kept so rand is drawn in the
            % same sequence (and hence lands in the same places) as before.
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        block = data{i, j, k}{flag};
                        for ii = 1: size(block, 1)
                            for jj = 1: size(block, 3)
                                row = block(ii, :, jj);
                                low = row < floorval;
                                row(low) = floorval * rand(sum(low), 1);
                                bad = isnan(row);
                                row(bad) = floorval * rand(sum(bad), 1);
                                block(ii, :, jj) = row;
                            end
                        end
                        datao{i, j, k} = {block};
                    end
                end
            end
            
            dataout = {datao};
            
        case 'spk_data_prep'
            % Run constrained_foopsi (options.p = 2) on every row of every cell of
            % datain{1}. Returns {spks, gs, bs, spkthres, datao}, each shaped like
            % datain{1}:
            %   spks     - spk output per row (presumably inferred spiking)
            %   gs       - 2 x nrows, the gn output per row
            %   bs       - nrows x 1, the cb output (baseline) per row
            %   spkthres - nrows x 1, mean(spk, 'omitnan') + 2 * mad(spk) per row
            %   datao    - cc + cb per row (denoised trace)
            dff = datain{1};
            options.p = 2;
            [spks, gs, bs, spkthres, datao] = deal(dff);
            for i = 1: size(dff, 1)
                for j = 1: size(dff, 2)
                    for k = 1: size(dff, 3)
                        traces = dff{i, j, k};
                        nrow = size(traces, 1);
                        rowspk = traces;
                        rowfit = traces;
                        rowg = zeros(2, nrow);
                        rowb = zeros(nrow, 1);
                        rowthr = zeros(nrow, 1);
                        % every output is sliced by ii so parfor can distribute rows
                        parfor ii = 1: nrow
                            [cc, cb, ~, gn, ~, spk] = constrained_foopsi(traces(ii, :), [], 0, [], [], options);
                            rowspk(ii, :) = spk;
                            rowg(:, ii) = gn;
                            rowb(ii, :) = cb;
                            rowthr(ii, :) = mean(spk, 'omitnan') + 2 * mad(spk);
                            rowfit(ii, :) = cc + cb;
                        end
                        spks{i, j, k} = rowspk;
                        gs{i, j, k} = rowg;
                        bs{i, j, k} = rowb;
                        spkthres{i, j, k} = rowthr;
                        datao{i, j, k} = rowfit;
                    end
                end
            end
            
            dataout = {spks, gs, bs, spkthres, datao};
                        
        case 'behav_index'
            % Build per-frame logical masks (1 x nf, nf = size(datain{1}{i,j,k}, 2)):
            %   fdata.wh{i,j,k} : 2 x nbouts [start; stop] frames -> idxwh
            %   fdata.wp{i,j,k} : wipe onset frames, each marking wlen frames -> idxwp
            %   fdata.s{i,j,k}  : stimulus onset frames (NaN entries skipped), each
            %                     marking slen(i) frames, slen chosen by row i -> idxs
            % Returns {idxwp, idxwh, idxs}.
            dff = datain{1};
            fdata = datain{2};
            ltc = 0;                    % onset latency in frames (an earlier version used 1 * Fsi)
            slen = [2, 5, 2] * Fsi;     % stimulus window length, indexed by row i
            wlen = round(2 * Fsi);      % wipe window length
            [idxs, idxwh, idxwp] = deal(dff);
            [nc, nw, nm] = size(dff);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        nf = size(dff{i, j, k}, 2);
                        
                        bouts = fdata.wh{i, j, k};
                        mask = false(1, nf);
                        for b = 1: size(bouts, 2)
                            mask((ltc + bouts(1, b)): min(nf, bouts(2, b))) = true;
                        end
                        idxwh{i, j, k} = mask;
                        
                        wipes = fdata.wp{i, j, k};
                        mask = false(1, nf);
                        for w = 1: length(wipes)
                            mask((ltc + wipes(w)): min(nf, wipes(w) + wlen)) = true;
                        end
                        idxwp{i, j, k} = mask;
                        
                        stims = fdata.s{i, j, k};
                        mask = false(1, nf);
                        for s = 1: length(stims)
                            if isnan(stims(s))
                                continue
                            end
                            mask((ltc + max(1, stims(s))): min(nf, stims(s) + slen(i))) = true;
                        end
                        idxs{i, j, k} = mask;
                    end
                end
            end
            
            dataout = {idxwp, idxwh, idxs};

        case {'extract_wipe', 'extract_whisk', 'extract_stim'}
            % Wipe, whisk and stimulus extraction share one implementation,
            % extract_factor, which is told which factor to extract through fflag.
            dataout = extract_factor(datain, fflag);
            
        case 'epoch_trial'
            % Cut fixed-length epochs around every non-NaN event.
            %   datain{1} : cell array of datasets; datai{jj}{i,j,k} is a rows x frames
            %               matrix (all datasets share the i,j,k layout of datai{1})
            %   datain{2} : fdata{i,j,k}, event indices in behaviour samples (Fsb);
            %               NaN entries are ignored
            %   datain{3} : tstep = [pre, post]; times Fsi gives the frames kept before
            %               and after each event (presumably seconds -- TODO confirm)
            % Returns {datas} with datas{i,j,k}{jj} = rows x (Fsi*sum(tstep)+1) x ntrials.
            % Events whose window does not lie entirely inside the recording are dropped.
            datai = datain{1};
            fdata = datain{2};
            tstep = datain{3};
            [nc, nw, nm] = size(datai{1});
            
            %%% trial epoch: stimulus %%%
            datas = datai{1};
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        btc = fdata{i, j, k}; % input should specify which factor to use %
                        btc = btc(~isnan(btc));
                        ntr = length(btc);
                        % Window limits in imaging frames. The bounds test below uses
                        % these converted limits, so it stays correct when Fsi ~= Fsb.
                        t0 = round(btc * Fsi / Fsb - Fsi * tstep(1));
                        t1 = round(btc * Fsi / Fsb + Fsi * tstep(2));
                        temp = cell(1, length(datai));
                        for jj = 1: length(datai)
                            dtct = datai{jj}{i, j, k};
                            % A window ending exactly on the last frame is valid. An
                            % explicit mask (instead of testing the output for NaN)
                            % keeps trials whose data legitimately start with NaN.
                            valid = t0 >= 1 & t1 <= size(dtct, 2);
                            tpt = NaN(size(dtct, 1), Fsi * sum(tstep) + 1, ntr);
                            for ii = 1: ntr
                                if valid(ii)
                                    tpt(:, :, ii) = dtct(:, t0(ii): t1(ii));
                                end
                            end
                            temp{jj} = tpt(:, :, valid);
                        end
                        datas{i, j, k} = temp;
                    end
                end
            end
            
            dataout = {datas};
                        
        case 'epoch_trial_by_behavior'
            % Split epoched trials into whisk (> 0) and no-whisk (== 0) sets.
            %   datain{1}    : datas{i,j,k}{kk} = rows x frames x trials
            %   datain{2}{1} : stimulus indices per cell; NaN stimuli are dropped from
            %                  the whisk vector before it is used as a trial mask
            %   datain{2}{2} : whisk value per stimulus; when empty, both outputs
            %                  receive every trial
            % Returns {dataswhnwh} with dataswhnwh{i,j,k}{kk} = {whisk, no-whisk}, or []
            % when datas{i,j,k} is empty.
            % NOTE(review): the third dimension of datas must line up with the non-NaN
            % stimuli; 'epoch_trial' also drops out-of-range stimuli, which would
            % misalign these masks -- confirm upstream data avoid that case.
            datas = datain{1}; % input epoched data %
            fdatas = datain{2}{1};
            fdatab = datain{2}{2};
            [nc, nw, nm] = size(datas);
            
            dataswhnwh = datas;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        nsets = length(datas{i, j, k});
                        out = [];
                        if nsets > 0
                            stim = fdatas{i, j, k};
                            whisk = fdatab{i, j, k};
                            nosplit = isempty(whisk);
                            if ~nosplit
                                whisk = whisk(~isnan(stim));
                            end
                            for kk = 1: nsets
                                trials = datas{i, j, k}{kk};
                                if nosplit
                                    out{kk} = {trials, trials};
                                else
                                    out{kk} = {trials(:, :, whisk > 0), trials(:, :, whisk == 0)};
                                end
                            end
                        end
                        dataswhnwh{i, j, k} = out;
                    end
                end
            end
            
            dataout = {dataswhnwh};
            
        case 'get_first_wipe_moment'
            % For each stimulus, take the first wipe onset with a positive lag (assumes
            % fdata.wp is sorted). If that lag is below ithres(4 - i) * Fsb samples, record
            % the wipe time in epoch frames, tstep(1) * Fsi + lag * Fsi / Fsb, split by
            % whisk (> 0) / no-whisk trial.
            %   datain{1} : struct with fields s (stimulus indices), b (whisk value per
            %               stimulus) and wp (wipe onsets), all in Fsb samples
            %   datain{2} : tstep; tstep(1) * Fsi is the pre-stimulus part of an epoch
            % Returns {dataswhnwhwps}; entries are {whisk times, no-whisk times}, or []
            % where fdata.b is empty.
            fdata = datain{1};
            tstep = datain{2};
            fdatas = fdata.s;
            fdatab = fdata.b;
            fdatawp = fdata.wp;
            [nc, nw, nm] = size(fdatas);
            ithres = [2, 5]; %%% window to detect wipe after stim %%%
            
            %%% get wipe moment by trial %%%
            dataswhnwhwps = fdatas;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        dataswhnwhwps{i, j, k} = [];
                        btc = fdatas{i, j, k};
                        whtc = fdatab{i, j, k};
                        wptc = fdatawp{i, j, k};
                        if ~isempty(whtc)
                            % Use one joint mask so stimulus times and whisk values stay
                            % aligned. Filtering btc first and then indexing whtc with
                            % ~isnan(btc) kept a prefix of whtc instead. Assumes s and b
                            % hold one entry per stimulus.
                            valid = ~isnan(btc(:)) & ~isnan(whtc(:));
                            btc = btc(valid);
                            whtc = whtc(valid);
                            wpswh = [];
                            wpsnwh = [];
                            for ii = 1: length(btc)
                                lag = wptc - btc(ii);
                                lag = lag(find(lag > 0, 1));
                                % NOTE(review): ithres(4 - i) only exists for i = 2, 3
                                lag = lag(lag < ithres(4 - i) * Fsb);
                                t = tstep(1) * Fsi + lag * Fsi / Fsb;
                                if whtc(ii) > 0
                                    wpswh = [wpswh, t];
                                else
                                    wpsnwh = [wpsnwh, t];
                                end
                            end
                            dataswhnwhwps{i, j, k} = {wpswh, wpsnwh};
                        end
                    end
                end
            end
            
            dataout = {dataswhnwhwps};
            
        case 'concatenate_all_mice_noxious'
            % Pool all mice for rows i >= 2 of datain{1}, per split ii = 1, 2.
            %   datain{1} : dataswhnwh{i,j,k}{kk}{ii} = rows x frames x trials
            %   datain{2} : dataswhnwhwps{i,j,k}{ii} = row vector of times, or [] for
            %               cells without behaviour data (presumably the output of
            %               'get_first_wipe_moment')
            %   datain{3} : tstep; frames 1 : tstep(1) * Fsi form the baseline
            % Returns {dataswhnwhmc, dataswhnwhwpsmc}, both (nc - 1) x nw x 2:
            %   dataswhnwhmc{i-1,j,ii}{kk} - trial means minus baseline mean, all mice
            %                                stacked vertically
            %   dataswhnwhwpsmc{i-1,j,ii}  - times of all mice concatenated
            dataswhnwh = datain{1};
            dataswhnwhwps = datain{2};
            tstep = datain{3};
            [nc, nw, nm] = size(dataswhnwh);
            bl = 1: tstep(1) * Fsi;   % baseline frames
            
            %%% concatenate all mice based on whisk no whisk %%%
            dataswhnwhmc = cell(nc - 1, nw, 2);
            dataswhnwhwpsmc = cell(nc - 1, nw, 2);
            for i = 2: nc
                for j = 1: nw
                    for ii = 1: 2
                        dataswhnwhmc{i - 1, j, ii} = cell(1, length(dataswhnwh{i, j, 1}));
                        dataswhnwhwpsmc{i - 1, j, ii} = [];
                        for k = 1: nm
                            for kk = 1: length(dataswhnwh{i, j, k})
                                tp = mean(dataswhnwh{i, j, k}{kk}{ii}, 3);
                                tp = tp - mean(tp(:, bl), 2);
                                dataswhnwhmc{i - 1, j, ii}{kk} = [dataswhnwhmc{i - 1, j, ii}{kk}; tp];
                            end
                            % Cells without behaviour data hold []. Brace-indexing []
                            % throws, so those mice contribute no times.
                            wps = dataswhnwhwps{i, j, k};
                            if ~isempty(wps)
                                dataswhnwhwpsmc{i - 1, j, ii} = [dataswhnwhwpsmc{i - 1, j, ii}, wps{ii}];
                            end
                        end
                    end
                end
            end
            
            dataout = {dataswhnwhmc, dataswhnwhwpsmc};
            
        case 'extract_pseudo'
            % Pseudo-event extraction is delegated to extract_factor, keyed by fflag.
            factorargs = {datain, fflag};
            dataout = extract_factor(factorargs{:});
                        
        case 'smooth_spks'
            % Smooth every cell of datain{1} along dimension 2 with a unit-sum Gaussian
            % window (gausswin) of length wlen.
            %   datain{2} (optional) : window length in samples, default 3
            % Returns {fr}, shaped like datain{1}.
            spks = datain{1};
            if length(datain) < 2
                wlen = 3;
            else
                % Previously there was no else branch, so passing datain{2} left wlen
                % undefined and the call errored.
                wlen = datain{2};
            end
            % Row kernel, so convn smooths along the second (time) dimension. It is
            % built once rather than once per cell.
            kernel = gausswin(wlen)' / sum(gausswin(wlen));
            
            fr = spks;
            for c = 1: numel(spks)
                fr{c} = convn(spks{c}, kernel, 'same');
            end
            
            dataout = {fr};
            
        case 'get_average_spksss_trial_neuron'
            % Pool epoched data across mice and average over the pooled columns.
            %   datain{1} : datasextsswhnwh{i,j,k}; x{ispks}{bi} (rows x time x trials,
            %               bi = 1 or 2) is reshaped to time x (rows*trials) and pooled
            %               over k
            %   datain{2} : gs{i,j,k}, pooled over k and averaged along dimension 2
            %   datain{3} : dy, multiplied element-wise (as dy') into heat columns 3:6
            %   datain{4} : optional flag (default 1). flag == 2 keeps only pooled
            %               columns with at least one sample > 0 before averaging
            % Returns {vfs, heats, vfgs, heatgs}, six columns each:
            %   vf   : (1,1) b1, (1,2) b1, (3,1) b1, (3,1) b2, (3,2) b1, (3,2) b2
            %   heat : (1,1) b1, (1,2) b1, (2,1) b1, (2,1) b2, (2,2) b1, (2,2) b2
            datasextsswhnwh = datain{1};
            gs = datain{2};
            dy = datain{3};
            % The flag used to be read from datain{3}, the same cell as dy, behind a
            % "< 3" length test that could never be true once dy had been read. It is
            % now taken from datain{4}, as in 'group_neuron_spike_event_average'.
            if length(datain) < 4
                flag = 1;
            else
                flag = datain{4};
            end
            ispks = 2;
            
            % time x (rows*trials) matrix pooled over mice for cell (ci, wi), split bi
            pool = @(ci, wi, bi) cell2mat(squeeze(cellfun( ...
                @(x) reshape(permute(x{ispks}{bi}, [2, 1, 3]), size(x{ispks}{bi}, 2), []), ...
                datasextsswhnwh(ci, wi, :), 'uniformoutput', false))');
            % gs pooled over mice, then averaged along dimension 2
            poolgs = @(ci, wi) mean(cell2mat(squeeze(gs(ci, wi, :))'), 2);
            if flag == 2
                % Threshold is 0; a scl * mad(...) threshold was tried previously.
                prep = @(m) m(:, sum(m > 0, 1) > 0);
            else
                prep = @(m) m;
            end
            colmean = @(m) mean(prep(m), 2);
            
            % air columns are shared by the vf and heat outputs
            air = [colmean(pool(1, 1, 1)), colmean(pool(1, 2, 1))];
            airgs = [poolgs(1, 1), poolgs(1, 2)];
            
            % vf %
            vfs = [air, colmean(pool(3, 1, 1)), colmean(pool(3, 1, 2)), ...
                colmean(pool(3, 2, 1)), colmean(pool(3, 2, 2))];
            vfgs = [airgs, poolgs(3, 1), poolgs(3, 1), poolgs(3, 2), poolgs(3, 2)];
            
            % heat %
            heats = [air, colmean(pool(2, 1, 1)), colmean(pool(2, 1, 2)), ...
                colmean(pool(2, 2, 1)), colmean(pool(2, 2, 2))];
            heats(:, 3: end) = heats(:, 3: end) .* dy';
            heatgs = [airgs, poolgs(2, 1), poolgs(2, 1), poolgs(2, 2), poolgs(2, 2)];
            
            dataout = {vfs, heats, vfgs, heatgs};
                        
        case 'group_neuron_spike_event_average'
            % Pool epoched data across mice (time x (rows*trials) per cell/split). For
            % each time point, return the fraction of pooled columns whose value
            % exceeds that column's own std over time.
            %   datain{1} : datasextsswhnwh{i,j,k}; x{ispks}{bi} = rows x time x trials
            %   datain{2} : gs{i,j,k}, pooled over mice and averaged along dimension 2
            %   datain{3} : dy, multiplied element-wise (as dy') into heat columns 3:6
            %               before thresholding
            % Returns {vfs, heats, vfgs, heatgs}, six columns each (same column layout
            % as 'get_average_spksss_trial_neuron').
            datasextsswhnwh = datain{1};
            gs = datain{2};
            dy = datain{3};
            if length(datain) < 4
                flag = 1;
            else
                flag = datain{4};
            end
            % NOTE(review): flag is parsed but not used by this branch
            ispks = 2;
            
            pool = @(ci, wi, bi) cell2mat(squeeze(cellfun( ...
                @(x) reshape(permute(x{ispks}{bi}, [2, 1, 3]), size(x{ispks}{bi}, 2), []), ...
                datasextsswhnwh(ci, wi, :), 'uniformoutput', false))');
            poolgs = @(ci, wi) cell2mat(squeeze(gs(ci, wi, :))');
            fracabove = @(m) mean(m > std(m, [], 1), 2);
            colfrac = @(mats) cell2mat(cellfun(fracabove, mats, 'uniformoutput', false));
            
            air = {pool(1, 1, 1), pool(1, 2, 1)};
            airgs = [mean(poolgs(1, 1), 2), mean(poolgs(1, 2), 2)];
            
            % vf %
            vfs = colfrac([air, {pool(3, 1, 1), pool(3, 1, 2), pool(3, 2, 1), pool(3, 2, 2)}]);
            vfgs = [airgs, mean(poolgs(3, 1), 2), mean(poolgs(3, 1), 2), ...
                mean(poolgs(3, 2), 2), mean(poolgs(3, 2), 2)];
            
            % heat %
            hot = cellfun(@(x) x .* dy', ...
                {pool(2, 1, 1), pool(2, 1, 2), pool(2, 2, 1), pool(2, 2, 2)}, ...
                'uniformoutput', false);
            heats = colfrac([air, hot]);
            heatgs = [airgs, mean(poolgs(2, 1), 2), mean(poolgs(2, 1), 2), ...
                mean(poolgs(2, 2), 2), mean(poolgs(2, 2), 2)];
            
            dataout = {vfs, heats, vfgs, heatgs};

        case 'active_neuron_trial_number'
            % For every trial, compute the fraction of rows that are active in a
            % post-event window. A row is active when any window sample exceeds scl
            % times that row's std over the window. Uses x{2}{bi} of each cell
            % (rows x time x trials).
            %   vf   : (1,1) b1, (1,2) b1, (3,1) b1, (3,1) b2, (3,2) b1, (3,2) b2;
            %          window frames 2*Fsi+1 : 4*Fsi
            %   heat : (1,1) b1, (1,2) b1, (2,1) b1, (2,1) b2, (2,2) b1, (2,2) b2;
            %          window frames 2*Fsi+1 : 7*Fsi
            % Returns {vfs, heats}: 1 x 6 cells, each a column of the non-zero
            % fractions over mice and trials. Zero fractions are discarded along with
            % the zero padding of the mouse x trial matrix.
            dataextswhnwh = datain{1};
            scl = 1;
            
            pick = @(ci, wi, bi) squeeze(cellfun(@(x) x{2}{bi}, dataextswhnwh(ci, wi, :), 'uniformoutput', false));
            air = {pick(1, 1, 1), pick(1, 2, 1)};
            groups = { ...
                [air, {pick(3, 1, 1), pick(3, 1, 2), pick(3, 2, 1), pick(3, 2, 2)}], ...
                [air, {pick(2, 1, 1), pick(2, 1, 2), pick(2, 2, 1), pick(2, 2, 2)}]};
            windows = {2 * Fsi + 1: 4 * Fsi, 2 * Fsi + 1: 7 * Fsi};
            
            results = cell(1, 2);
            for g = 1: 2
                res = cell(1, 6);
                for m = 1: 6
                    permouse = groups{g}{m};
                    frac = [];
                    for k = 1: length(permouse)
                        trials = permouse{k};
                        for j = 1: size(trials, 3)
                            win = trials(:, windows{g}, j);
                            active = any(win > scl * std(win, [], 2), 2);
                            frac(k, j) = sum(active) / size(trials, 1);
                        end
                    end
                    res{m} = frac(frac > 0);
                end
                results{g} = res;
            end
            vfs = results{1};
            heats = results{2};
            
            dataout = {vfs, heats};
            
        case 'convert_from_spk_to_ca_trace'
            % Rebuild each column of datain{1} with ar_integrate, using the matching
            % column of datain{2} (presumably AR coefficients). Returns {dff}, same
            % size as datain{1}.
            spk = datain{1};
            gs = datain{2};
            [~, ncol] = size(spk);
            
            dff = spk;
            for col = 1: ncol
                dff(:, col) = ar_integrate(spk(:, col), gs(:, col));
            end
            
            dataout = {dff};
            
        case 'get_epoch_behav_index'
            % Merge overlapping behaviour events into epochs and return each epoch's
            % first frame.
            %   datain{1} : fwp{i,j,k} is either a 1 x n row of onsets, each covering
            %               Fsi * wlen frames after it, or a multi-row matrix whose
            %               rows 1 and 2 are [start; stop] frames per column; empty
            %               cells are left unchanged
            % Returns {fwpn} with fwpn{i,j,k} = 1 x nepochs onset frames.
            fwp = datain{1};
            [nc, nw, nm] = size(fwp);
            wlen = 2;
            
            fwpn = fwp;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        ft = fwpn{i, j, k};
                        if ~isempty(ft)
                            idx = false(1, max(ft(:)) + Fsi * wlen);
                            if size(ft, 1) == 1
                                for ii = 1: length(ft)
                                    idx(ft(ii): ft(ii) + Fsi * wlen) = true;
                                end
                            else
                                % Loop over columns. length(ft) is 2 for a single 2 x 1
                                % bout, which indexed a nonexistent second column.
                                for ii = 1: size(ft, 2)
                                    idx(ft(1, ii): ft(2, ii)) = true;
                                end
                            end
                            % First frame of every run of true values. This replaces
                            % bwlabeln plus one find per label.
                            fwpn{i, j, k} = find(diff([false, idx]) == 1);
                        end
                        disp(num2str([i, j, k]))
                    end
                end
            end

            dataout = {fwpn};
            
        case 'compute_STA'
            % Wipe-triggered windows. For each stimulus, take the first wipe onset with
            % a positive lag (assumes fdata.wp is sorted). If the lag is below
            % ithres(4 - i) * Fsb samples, cut datain{1}{i,j,k} from wsize(1)*Fsi to
            % wsize(2)*Fsi frames around that wipe and stack the windows by whisk (> 0)
            % / no-whisk trial.
            %   datain{1} : dataexts{i,j,k} = rows x frames
            %   datain{2} : struct with fields s (stimulus indices), b (whisk value per
            %               stimulus) and wp (wipe onsets), all in Fsb samples
            % Returns {dataextssta(2:end,:,:)}; entries are {whisk, no-whisk} stacks of
            % rows x (diff(wsize)*Fsi+1) x n, or [] where fdata.b is empty.
            dataexts = datain{1};
            fdata = datain{2};
            wsize = [-2, 2];
            fdatas = fdata.s;
            fdatab = fdata.b;
            fdatawp = fdata.wp;
            [nc, nw, nm] = size(fdatas);
            ithres = [2, 5];
            win = wsize(1) * Fsi: wsize(2) * Fsi;   % frame offsets around the wipe
            
            dataextssta = fdatas;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        dataextssta{i, j, k} = [];
                        btc = fdatas{i, j, k};
                        whtc = fdatab{i, j, k};
                        wptc = fdatawp{i, j, k};
                        dtmp = dataexts{i, j, k};
                        if ~isempty(whtc)
                            % Use one joint mask so stimulus times and whisk values stay
                            % aligned. Filtering btc first and then indexing whtc with
                            % ~isnan(btc) kept a prefix of whtc instead.
                            valid = ~isnan(btc(:)) & ~isnan(whtc(:));
                            btc = btc(valid);
                            whtc = whtc(valid);
                            stawh = [];
                            stanwh = [];
                            for ii = 1: length(btc)
                                lag = wptc - btc(ii);
                                idt = find(lag > 0, 1);
                                % NOTE(review): ithres(4 - i) only exists for i = 2, 3
                                if ~isempty(idt) && lag(idt) < ithres(4 - i) * Fsb
                                    tw = win + round(wptc(idt) * Fsi / Fsb);
                                    % skip wipes whose window leaves the recording
                                    if tw(1) >= 1 && tw(end) <= size(dtmp, 2)
                                        if whtc(ii) > 0
                                            stawh = cat(3, stawh, dtmp(:, tw));
                                        else
                                            stanwh = cat(3, stanwh, dtmp(:, tw));
                                        end
                                    end
                                end
                            end
                            dataextssta{i, j, k} = {stawh, stanwh};
                        end
                    end
                end
            end
            
            dataout = {dataextssta(2: end, :, :)};
            
        case 'compute_STA_group_average'
            % Group-average spike-triggered averages (STAs) pooled over sessions.
            % datain{1} sta : cell array; sta{r, c, m} is a cell whose x{1} / x{2}
            %                 are the whisk / no-whisk STA arrays. Per the original
            %                 labels, row r = 2 is vf and r = 1 is heat; column c = 1
            %                 is "whisker", c = 2 "no whisker".
            % datain{2} gs  : cell array; gs{r, c, m} matrices are concatenated
            %                 column-wise and averaged along dim 2.
            % datain{3} flag: when 2, drop traces (columns) that never exceed 0.
            % Output: {vfs, heats, vfgs, heatgs}; each has 4 columns ordered
            %   [whisker+whisk, whisker+no whisk, no whisker+whisk, no whisker+no whisk].
            %   gs has no whisk split, so in vfgs/heatgs columns 1-2 and 3-4 repeat.
            % NOTE(review): gs is indexed with the same row as sta (heat -> gs row 1).
            %   This assumes gs has the same row layout as sta - confirm with caller.
            sta = datain{1};
            gs = datain{2};
            flag = datain{3};

            % Flatten one set of STAs: each session's x{q} is permuted so its 2nd
            % dim becomes rows, reshaped to size(x{q}, 2) x [], and all sessions are
            % concatenated column-wise.
            stacat = @(s, q) cell2mat(squeeze(cellfun(@(x) reshape(permute(x{q}, [2, 1, 3]), ...
                size(x{q}, 2), []), s, 'uniformoutput', false))');

            srow = [2, 1];   % sta/gs rows: 2 -> vf, 1 -> heat
            stam = cell(1, 2);
            gsm = cell(1, 2);
            for r = 1: numel(srow)
                m = cell(1, 4);
                g = cell(1, 4);
                for c = 1: 2          % 1 = whisker, 2 = no whisker
                    gcol = mean(cell2mat(squeeze(gs(srow(r), c, :))'), 2);
                    for q = 1: 2      % 1 = whisk, 2 = no whisk
                        traces = stacat(sta(srow(r), c, :), q);
                        if flag == 2
                            % keep only traces with at least one positive sample
                            traces = traces(:, any(traces > 0, 1));
                        end
                        m{2 * (c - 1) + q} = mean(traces, 2);
                        g{2 * (c - 1) + q} = gcol;
                    end
                end
                stam{r} = [m{:}];
                gsm{r} = [g{:}];
            end

            dataout = {stam{1}, stam{2}, gsm{1}, gsm{2}};

        case 'compute_STA_group'
            % Pool spike-triggered averages (STAs) over sessions without averaging.
            % datain{1} sta : cell array; sta{r, c, m} is a cell whose x{1} / x{2}
            %                 are the whisk / no-whisk STA arrays. Per the original
            %                 labels, row r = 2 is vf and r = 1 is heat; column c = 1
            %                 is "whisker", c = 2 "no whisker".
            % datain{2} gs  : cell array; gs{r, c, m} matrices are concatenated
            %                 column-wise and averaged along dim 2.
            % datain{3} flag: when 2, drop traces (columns) that never exceed 0.
            % Output: {vfs, heats, vfgs, heatgs}.
            %   vfs / heats: 1 x 4 cells of pooled trace matrices ordered
            %   {whisker+whisk, whisker+no whisk, no whisker+whisk, no whisker+no whisk}.
            %   vfgs / heatgs: 4-column mean gs traces in the same order (gs has no
            %   whisk split, so columns 1-2 and 3-4 repeat).
            % NOTE(review): gs is indexed with the same row as sta (heat -> gs row 1).
            %   This assumes gs has the same row layout as sta - confirm with caller.
            sta = datain{1};
            gs = datain{2};
            flag = datain{3};

            % Flatten one set of STAs: each session's x{q} is permuted so its 2nd
            % dim becomes rows, reshaped to size(x{q}, 2) x [], and all sessions are
            % concatenated column-wise.
            stacat = @(s, q) cell2mat(squeeze(cellfun(@(x) reshape(permute(x{q}, [2, 1, 3]), ...
                size(x{q}, 2), []), s, 'uniformoutput', false))');

            srow = [2, 1];   % sta/gs rows: 2 -> vf, 1 -> heat
            stam = cell(1, 2);
            gsm = cell(1, 2);
            for r = 1: numel(srow)
                m = cell(1, 4);
                g = cell(1, 4);
                for c = 1: 2          % 1 = whisker, 2 = no whisker
                    gcol = mean(cell2mat(squeeze(gs(srow(r), c, :))'), 2);
                    for q = 1: 2      % 1 = whisk, 2 = no whisk
                        traces = stacat(sta(srow(r), c, :), q);
                        if flag == 2
                            % keep only traces with at least one positive sample
                            traces = traces(:, any(traces > 0, 1));
                        end
                        m{2 * (c - 1) + q} = traces;
                        g{2 * (c - 1) + q} = gcol;
                    end
                end
                stam{r} = m;
                gsm{r} = [g{:}];
            end

            dataout = {stam{1}, stam{2}, gsm{1}, gsm{2}};
            
        case 'sct_register'
            % Register ROI footprints across all (condition, whisker) recordings of
            % each session with cellreg_logdemons, once per choice of first entry.
            % datain{1} roiall: nc x nw x nm cell; each entry is reshaped to
            %   pixh x pixw x [] and permuted to [] x pixh x pixw before registration.
            % Output: {idall, structall}, both nm x (nc * nw) cells. Column ref holds
            %   the result when the recording list (condition-fastest order) is
            %   rotated left so recording ref comes first.
            roiall = datain{1};
            [nc, nw, nm] = size(roiall);
            pixh = 200;
            pixw = 320;
            nrec = nc * nw;

            idall = cell(nm, nrec);
            structall = cell(nm, nrec);
            for m = 1: nm
                footprints = cellfun(@(x) permute(reshape(x, pixh, pixw, []), [3, 1, 2]), ...
                    reshape(roiall(:, :, m), 1, nrec), 'uniformoutput', false);
                for ref = 1: nrec
                    reg = cellreg_logdemons(circshift(footprints, -(ref - 1), 2));
                    idall{m, ref} = reg.cell_to_index_map;
                    structall{m, ref} = reg;
                end
            end

            dataout = {idall, structall};
            
        case 'sct_compute_graph_all_found'
            % Build, per session, a consensus table of neurons found in every
            % recording, from the rotated registrations made by 'sct_register'.
            % datain{1} idall: nm x nn cell; idall{m, j} is the cell_to_index_map
            %   of the registration whose recording order was rotated to start at
            %   recording j. Entries <= 0 are treated as "not found".
            % For each neuron index of recording idt (the column of registration 1
            % with the smallest maximum index), every registration votes for the
            % partner index in each other recording; the mode of the positive votes
            % is kept. Rows with any missing (NaN/0) partner are dropped.
            % Output: {idfinal}, 1 x nm cell of index tables whose columns are in the
            %   original recording order.
            idall = datain{1};
            [nm, nn] = size(idall);

            idfinal = cell(1, nm);
            for i = 1: nm
                tmp = idall(i, :);
                % one row per candidate neuron: smallest per-column maximum index
                % over all registrations
                rtmp = zeros(min(cellfun(@min, cellfun(@(x) max(x, [], 1), tmp, 'uniformoutput', false))), nn - 1);
                [~, idt] = min(max(tmp{1}, [], 1));
                for j = 1: nn
                    % undo the rotation used for registration j (-> original order) ...
                    tmp{j} = [tmp{j}(:, nn - j + 2: end), tmp{j}(:, 1: nn - j + 1)];
                    % ... then rotate so recording idt is column 1
                    tmp{j} = [tmp{j}(:, idt: end), tmp{j}(:, 1: idt - 1)];
                end

                for j = 1: size(rtmp, 1)
                    % one vote row per registration k = 1..nn
                    tmpt = zeros(nn, nn - 1);
                    for k = 1: nn
                        idtt = tmp{k}(:, 1) == j;
                        if sum(idtt) > 0
                            tmpt(k, :) = tmp{k}(idtt, 2: end);
                        end
                    end
                    for k = 1: nn - 1
                        tmpp = tmpt(:, k);
                        tmpp = tmpp(tmpp > 0);
                        % mode of an empty vector is NaN, which fails the > 0 test below
                        rtmp(j, k) = mode(tmpp);
                    end
                end
                idfinal{i} = [(1: size(rtmp, 1))', rtmp];
                idtt = idfinal{i};
                idtt = sum(idtt > 0, 2) == size(idtt, 2);
                % keep complete rows and rotate columns back to the original order
                idfinal{i} = idfinal{i}(idtt, [nn - idt + 2: nn, 1: nn - idt + 1]);
            end

            dataout = {idfinal};
            
        case 'sct_get_traces_all_found'
            % Restrict spike and dF/F traces to the neurons tracked in every recording.
            % datain{1} = {spks, dffs}: nc x nw x nm cells of neuron-by-time matrices.
            % datain{2} = idfinal: 1 x nm cell; column sub2ind([nc, nw], c, w) of
            %   idfinal{m} lists the rows to keep for recording (c, w) of session m.
            % Output: {spkaf, dffaf} with the same layout and selected rows.
            spks = datain{1}{1};
            dffs = datain{1}{2};
            idfinal = datain{2};
            [nc, nw, nm] = size(spks);

            spkaf = spks;
            dffaf = dffs;
            for m = 1: nm
                for w = 1: nw
                    for c = 1: nc
                        rows = idfinal{m}(:, sub2ind([nc, nw], c, w));
                        spkaf{c, w, m} = spks{c, w, m}(rows, :);
                        dffaf{c, w, m} = dffs{c, w, m}(rows, :);
                    end
                end
            end

            dataout = {spkaf, dffaf};
 
        case 'sct_get_traces_epoch_all_found'
            % Apply the tracked-neuron selection to eight epoch datasets.
            % datain{1} = idfinal: 1 x nm cell; column sub2ind([nc, nw], c, w) of
            %   idfinal{m} lists the rows to keep for recording (c, w) of session m.
            % datain{2:9} = datas, datawp, datawh, dataswhnwh, dataexts, dataextwp,
            %   dataextwh, dataextswhnwh: nc x nw x nm cells whose entries are cells
            %   of arrays indexed as (rows, :, :). The 4th and 8th sets are one level
            %   deeper (cells of cells of arrays).
            % Output: the eight datasets, in input order, with rows selected.
            idfinal = datain{1};
            sets = reshape(datain(2: 9), 1, []);
            nested = [false, false, false, true, false, false, false, true];
            [nc, nw, nm] = size(sets{1});

            for s = 1: numel(sets)
                cur = sets{s};
                for m = 1: nm
                    for w = 1: nw
                        for c = 1: nc
                            rows = idfinal{m}(:, sub2ind([nc, nw], c, w));
                            group = cur{c, w, m};
                            for p = 1: length(group)
                                if nested(s)
                                    for q = 1: length(group{p})
                                        group{p}{q} = group{p}{q}(rows, :, :);
                                    end
                                else
                                    group{p} = group{p}(rows, :, :);
                                end
                            end
                            cur{c, w, m} = group;
                        end
                    end
                end
                sets{s} = cur;
            end

            dataout = sets;

        case 'sct_neuron_mean_std_event_frequency'
            % Per-neuron peak statistics, pooled over sessions.
            % data{c, w, m}{e} is reduced with max over dim 2 and then treated as
            % neuron x trial (presumably neuron x time x trial input - confirm).
            % Each neuron contributes one row [mean peak, SEM of peak, fraction of
            % trials whose peak exceeds one std of the whole array].
            % Output: {stats, rgy}; stats{c, w}{e} stacks those rows over sessions,
            %   rgy(e, :) = [0, 1.1 * largest per-neuron mean peak over all data].
            data = datain{1};
            [nc, nw, nm] = size(data);
            nep = length(data{1, 1, 1});

            rgy = zeros(nep, 2);
            for e = 1: nep
                peakmeans = cellfun(@(x) mean(squeeze(max(x{e},[], 2)), 2), data(:), 'uniformoutput', false);
                rgy(e, :) = [0, 1.1 * max(cell2mat(peakmeans))];
            end

            stats = cell(nc, nw);
            for c = 1: nc
                for w = 1: nw
                    acc = cell(1, length(data{c, w, 1}));   % entries start as []
                    for m = 1: nm
                        for e = 1: length(data{c, w, m})
                            arr = data{c, w, m}{e};
                            peaks = squeeze(max(arr, [], 2));
                            ntr = size(peaks, 2);
                            frac = sum(peaks > std(arr(:)), 2) ./ ntr;
                            acc{e} = [acc{e}; mean(peaks, 2), std(peaks, [], 2) ./ sqrt(ntr), frac(:)];
                        end
                    end
                    stats{c, w} = acc;
                end
            end

            dataout = {stats, rgy};

        case 'sct_neuron_mean_std_event_frequency_dist'
            % Kernel-density estimates of the statistics from
            % 'sct_neuron_mean_std_event_frequency'.
            % distr{c, w}{e} = [density of column 3 (fraction) on 40 points in [0, 1],
            %                   density of column 1 (mean peak) on 40 points in rgy(e, :)].
            % The returned rgy{e} is the 2 x 2 matrix [0, 1.1 * max density of col 1;
            % 0, 1.1 * max density of col 2] taken over all (c, w).
            stats = datain{1};
            rgy = datain{2};
            [nc, nw] = size(stats);
            npt = 40;
            fgrid = linspace(0, 1, npt);

            distr = stats;
            for c = 1: nc
                for w = 1: nw
                    for e = 1: length(stats{c, w})
                        st = stats{c, w}{e};
                        pgrid = linspace(rgy(e, 1), rgy(e, 2), npt);
                        distr{c, w}{e} = [ksdensity(st(:, 3), fgrid)', ksdensity(st(:, 1), pgrid)'];
                    end
                end
            end

            nep = length(distr{1, 1});
            rgy = cell(1, nep);
            for e = 1: nep
                colmax = cell2mat(cellfun(@(x) squeeze(max(x{e}, [], 1)), distr(:), 'uniformoutput', false));
                rgy{e} = [0, 0; 1.1 * max(colmax)]';
            end

            dataout = {distr, rgy};

        case 'sct_neuron_mean_std_event_frequency_whnwh'
            % Whisk / no-whisk split version of 'sct_neuron_mean_std_event_frequency'.
            % data{c, w, m}{e}{s} is one array per sub-condition s; sub-condition
            % counts are taken from session 1 (data{c, w, 1}{e}).
            % Rows are [mean peak, SEM of peak, fraction of trials whose peak exceeds
            % 2 std of the whole array].
            % NOTE(review): the non-split case uses a 1-std cutoff - confirm the 2x
            %   difference is intended.
            % Output: {stats, rgy}; stats{c, w}{e}{s} stacks rows over sessions,
            %   rgy(e, :) = [0, 1.1 * largest per-neuron mean peak].
            data = datain{1};
            [nc, nw, nm] = size(data);
            nep = length(data{1, 1, 1});

            rgy = zeros(nep, 2);
            for e = 1: nep
                sessmax = cellfun(@(x) max(cellfun(@(y) max(mean(squeeze(max(y, [], 2)), 2)), x{e}),[], 2), data(:), 'uniformoutput', false);
                rgy(e, :) = [0, 1.1 * max(cell2mat(sessmax))];
            end

            stats = cell(nc, nw);
            for c = 1: nc
                for w = 1: nw
                    ref = data{c, w, 1};
                    acc = cell(1, length(ref));
                    for e = 1: length(ref)
                        acc{e} = cell(1, length(ref{e}));   % entries start as []
                    end
                    for m = 1: nm
                        for e = 1: length(data{c, w, m})
                            for s = 1: length(ref{e})
                                arr = data{c, w, m}{e}{s};
                                peaks = squeeze(max(arr, [], 2));
                                ntr = size(peaks, 2);
                                frac = sum(peaks > 2 * std(arr(:)), 2) ./ ntr;
                                acc{e}{s} = [acc{e}{s}; mean(peaks, 2), std(peaks, [], 2) ./ sqrt(ntr), frac(:)];
                            end
                        end
                    end
                    stats{c, w} = acc;
                end
            end

            dataout = {stats, rgy};

        case 'sct_neuron_mean_std_event_frequency_dist_whnwh'
            % Kernel-density estimates of the statistics from
            % 'sct_neuron_mean_std_event_frequency_whnwh', per sub-condition s:
            % distr{c, w}{e}{s} = [density of column 3 on 40 points in [0, 1],
            %                      density of column 1 on 40 points in rgy(e, :)].
            % The returned rgy{e} is 2 x 2: [0, 1.1 * max col-1 density;
            % 0, 1.1 * max col-2 density] over all (c, w, s).
            stats = datain{1};
            rgy = datain{2};
            [nc, nw] = size(stats);
            npt = 40;
            fgrid = linspace(0, 1, npt);

            distr = stats;
            for c = 1: nc
                for w = 1: nw
                    for e = 1: length(stats{c, w})
                        for s = 1: length(stats{c, w}{e})
                            st = stats{c, w}{e}{s};
                            pgrid = linspace(rgy(e, 1), rgy(e, 2), npt);
                            distr{c, w}{e}{s} = [ksdensity(st(:, 3), fgrid)', ksdensity(st(:, 1), pgrid)'];
                        end
                    end
                end
            end

            nep = length(distr{1, 1});
            rgy = cell(1, nep);
            for e = 1: nep
                groupmax = @(x) max(cell2mat(cellfun(@(y) max(y, [], 1), x{e}, 'uniformoutput', false)'), [], 1);
                colmax = cell2mat(cellfun(groupmax, distr(:), 'uniformoutput', false));
                rgy{e} = [0, 0; 1.1 * max(colmax)]';
            end

            dataout = {distr, rgy};
            
        case 'sct_same_neuron_compare_max_intensity'
            % Peak of each tracked neuron's trial-averaged trace (mean over dim 3,
            % max over dim 2) in every (condition, whisker) recording, divided by its
            % peak in recording (1, 1).
            % Columns are ordered condition-fastest (column = c + (w - 1) * nc).
            % Neurons whose reference ratio (column 1) is NaN are removed.
            % Output: {sct_compare_int}, 1 x nm cell of 1 x (epoch types) cells.
            data = datain{1};
            [nc, nw, nm] = size(data);

            sct_compare_int = cell(1, nm);
            for m = 1: nm
                for e = 1: length(data{1, 1, 1})
                    peakof = @(c, w) max(mean(data{c, w, m}{e}, 3), [], 2);
                    base = peakof(1, 1);
                    ratios = [];
                    for col = 1: nc * nw
                        [c, w] = ind2sub([nc, nw], col);
                        ratios(:, col) = peakof(c, w) ./ base;
                    end
                    sct_compare_int{m}{e} = ratios(~isnan(ratios(:, 1)), :);
                end
            end

            dataout = {sct_compare_int};
            
        case 'sct_same_neuron_compare_event_frequency'
            % Per-neuron fraction of above-threshold samples (threshold = std of the
            % whole array) in the response window, averaged over time and trials and
            % divided by the same measure in recording (1, 1) over samples
            % 2 * Fs + 1 .. 4 * Fs.
            % Recording (c, w) uses samples 2 * Fs + 1 .. (2 + wid(c)) * Fs.
            % Columns are ordered condition-fastest (column = c + (w - 1) * nc);
            % neurons with a NaN reference ratio are removed.
            data = datain{1};
            [nc, nw, nm] = size(data);
            wid = [2, 2, 5];       % response-window length (s) per condition
            Fs = 20;               % sampling rate (Hz) for the window indices
            onset = 2 * Fs + 1;    % first sample of every window
            rate = @(x, last) mean(mean(x(:, onset: last, :) > std(x(:)), 2), 3);

            sct_compare_freq = cell(1, nm);
            for m = 1: nm
                for e = 1: length(data{1, 1, 1})
                    base = rate(data{1, 1, m}{e}, 4 * Fs);
                    ratios = [];
                    for col = 1: nc * nw
                        [c, w] = ind2sub([nc, nw], col);
                        ratios(:, col) = rate(data{c, w, m}{e}, (2 + wid(c)) * Fs) ./ base;
                    end
                    sct_compare_freq{m}{e} = ratios(~isnan(ratios(:, 1)), :);
                end
            end

            dataout = {sct_compare_freq};
            
        case 'sct_same_neuron_compare_ratio_stats'
            % Count tracked neurons by their increase/decrease pattern across
            % recordings, relative to the reference recording (column 1).
            % datain{1} dataint, datain{2} datafreq: 1 x nm cells whose entries are
            %   cells (one per epoch type) of neuron x recording ratio matrices, as
            %   produced by the *_compare_max_intensity / *_compare_event_frequency cases.
            % Per epoch type, sessions are pooled; rows with a NaN in column 1 or any
            % +/-Inf or 0 entry are dropped; ratios are log10-transformed; each neuron
            % is encoded as the logical row (log ratio > 0) over recordings 2:end.
            % Output: {incdec, ids}. ids{r, e} holds the distinct patterns (sorted
            %   unique rows) and incdec{r, e}(p) the number of neurons with pattern p;
            %   r = 1 for intensity, r = 2 for event frequency.
            dataint = datain{1};
            datafreq = datain{2};
            nep = length(dataint{1});

            pool = @(src, e) cell2mat(cellfun(@(x) x{e}, src, 'uniformoutput', false)');
            clean = @(mat) log10(mat(~isnan(mat(:, 1)) & sum(abs(mat) == Inf | mat == 0, 2) == 0, :));
            tp1 = cell(1, nep);
            tp2 = cell(1, nep);
            for i = 1: nep
                tp1{i} = clean(pool(dataint, i));
                tp2{i} = clean(pool(datafreq, i));
            end

            tps = {tp1, tp2};
            incdec = cell(2, nep);
            ids = cell(2, nep);
            for r = 1: 2
                for ii = 1: nep
                    pattern = tps{r}{ii}(:, 2: end) > 0;
                    uniq = unique(pattern, 'rows');
                    ids{r, ii} = uniq;
                    for j = 1: size(uniq, 1)
                        incdec{r, ii}(j) = sum(ismember(pattern, uniq(j, :), 'rows'));
                    end
                end
            end

            dataout = {incdec, ids};
            
        case 'sct_same_neuron_compare_active_ornot'
            % Fraction of trials in which each neuron is active in the response window.
            % datain{1} dataa: nc x nw x nm cell of neuron x time traces, used for the
            %   per-neuron threshold (min over dim 2 + std over dim 2).
            % datain{2} datae: matching cells; datae{c, w, m}{flag} is compared with
            %   that threshold (presumably neuron x time x trial - confirm).
            % datain{3} flag: which entry of datae{c, w, m} to use.
            % A trial counts when more than Fsi * 0.1 samples in the window
            % 2 * Fsi + 1 .. (2 + rg(c)) * Fsi + 1 exceed the threshold.
            % Output: {scout}, 1 x nm cell of neuron x (nc * nw) matrices; columns are
            %   ordered whisker-fastest (column = w + (c - 1) * nw).
            % NOTE(review): this window has one more sample than the windows in
            %   'sct_same_neuron_compare_event_frequency', and the column order differs
            %   from that case (condition-fastest) - confirm both are intended.
            dataa = datain{1};
            datae = datain{2};
            flag = datain{3};
            rg = [2, 5, 2];
            [nc, nw, nm] = size(dataa);
            minsamp = Fsi * 0.1;

            scout = cell(1, nm);
            for m = 1: nm
                for col = 1: nc * nw
                    [w, c] = ind2sub([nw, nc], col);
                    trace = dataa{c, w, m};
                    level = std(trace, [], 2) + min(trace, [], 2);
                    above = datae{c, w, m}{flag} > level;
                    win = 2 * Fsi + 1: (2 + rg(c)) * Fsi + 1;
                    scout{m}(:, col) = sum(sum(above(:, win, :), 2) > minsamp, 3) ./ size(above, 3);
                end
            end

            dataout = {scout};
            
        case 'sct_condition_distribution'
            % Summarize active neurons per recording from active-ratio matrices.
            % datain{1} scin: cell vector; scin{m} is neuron x recording, and a
            %   neuron counts as active when its ratio exceeds 0.1.
            % Output: {datao, dataor, dataoin}, each recording x session:
            %   datao   - number of active neurons,
            %   dataor  - that number divided by the neuron count,
            %   dataoin - neurons active both in recording 1 and in that recording.
            scin = datain{1};
            nm = length(scin);
            nrec = size(scin{1}, 2);
            thres = 0.1;

            datao = zeros(nrec, nm);
            dataor = zeros(nrec, nm);
            dataoin = zeros(nrec, nm);
            for m = 1: nm
                active = scin{m} > thres;
                nactive = sum(active, 1);
                datao(:, m) = nactive;
                dataor(:, m) = nactive ./ size(active, 1);
                for col = 1: size(active, 2)
                    dataoin(col, m) = nnz(active(:, 1) & active(:, col));
                end
            end

            dataout = {datao, dataor, dataoin};

        %%% ------------------- sct condition overlap ------------------- %%%
        
        case 'sct_condition_distribution_overlap'
            % Binarize each session's active-ratio matrix (ratio > 0.1 -> active).
            % datain{1} scin: cell vector of neuron x recording matrices.
            % Output: {datao}, 1 x length(scin) cell of logical matrices.
            scin = datain{1};
            nm = length(scin);
            thres = 0.1;
            datao = reshape(cellfun(@(x) x > thres, scin(1: nm), 'uniformoutput', false), 1, []);

            dataout = {datao};

        %%% ---------------------------- tca ------------------- %%%

        case 'tca'
            % Fit nonnegative CP models (tensor component analysis) repeatedly to
            % each recording and keep the lowest-error fits.
            % datain{1} data: nc x nw x nm cell; data{c, w, m}{flag} is the array
            %   decomposed with ncp (modes 1/2/3 are stored as neuron/temporal/trial).
            % datain{2} flag (optional, default 1): which entry of data{c, w, m}.
            % datain{3} sflag (optional, default 0): if true, smooth along dim 2 with
            %   a normalized 9-point Gaussian window before fitting.
            % Each recording is fit nfit = 2 * nrp times; the nrp fits with the lowest
            % final relative error are kept. Within each fit, components are sorted
            % by the index of their temporal peak.
            % Output: {datao}, struct with fields neuron, temporal, trial, lambda
            %   (nc x nw x nm cells of 1 x nrp cells) and relerr (1 x nrp vectors).
            data = datain{1};
            if length(datain) < 2
                flag = 1;
            else
                flag = datain{2};
            end
            if length(datain) < 3
                sflag = 0;
            else
                sflag = datain{3};
            end
            [nc, nw, nm] = size(data);
            nn = 80;          % CP components per fit
            nrp = 50;         % fits kept per recording
            nfit = 2 * nrp;   % fits run per recording

            datao.neuron = cell(nc, nw, nm);
            datao.temporal = cell(nc, nw, nm);
            datao.trial = cell(nc, nw, nm);
            datao.lambda = cell(nc, nw, nm);
            datao.relerr = cell(nc, nw, nm);
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        tmp = data{i, j, k}{flag};
                        if sflag
                            tmp = convn(tmp, (gausswin(9) / sum(gausswin(9)))', 'same');
                        end
                        % sized for every fit run, since parfor fills all nfit slots
                        neur = cell(1, nfit);
                        temp = cell(1, nfit);
                        tria = cell(1, nfit);
                        lamb = cell(1, nfit);
                        recerr = zeros(1, nfit);
                        parfor ii = 1: nfit
                            [t, ~, rec] = ncp(tensor(tmp), nn, 'tol', 1e-6, 'max_iter', 800, 'min_iter', 100);
                            recerr(ii) = rec.final.rel_Error;
                            % order components by the position of their temporal peak
                            [~, idt] = max(t.U{2}, [], 1);
                            [~, idt] = sort(idt);
                            neur{ii} = t.U{1}(:, idt);
                            temp{ii} = t.U{2}(:, idt);
                            tria{ii} = t.U{3}(:, idt);
                            lamb{ii} = t.lambda(idt);
                        end

                        % keep the nrp fits with the lowest final relative error
                        [~, iduse] = sort(recerr);
                        iduse = iduse(1: nrp);
                        datao.neuron{i, j, k} = neur(iduse);
                        datao.temporal{i, j, k} = temp(iduse);
                        datao.trial{i, j, k} = tria(iduse);
                        datao.lambda{i, j, k} = lamb(iduse);
                        datao.relerr{i, j, k} = recerr(iduse);
                    end
                end
            end

            dataout = {datao};

        case 'tca_bootstrap'
            % Build nbt component sets per recording from the fits stored by 'tca'.
            % With nsb == 1, set ii is stored fit ii + offset. With nsb > 1, each set
            % averages nsb fits drawn with randsample from all stored fits.
            % datain{1} data: struct with fields neuron, temporal, trial, lambda,
            %   relerr (nc x nw x nm cells of per-fit entries).
            % Output: {datao}, same layout with each field replaced by 1 x nbt cells
            %   (relerr: 1 x nbt vector of mean errors).
            data = datain{1};
            [nc, nw, nm] = size(data.temporal);
            nbt = 20;     % component sets produced per recording
            offset = 0;   % first stored fit used when nsb == 1
            nsb = 1;      % fits averaged per set

            datao = data;
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        neurt = data.neuron{i, j, k};
                        tempt = data.temporal{i, j, k};
                        triat = data.trial{i, j, k};
                        lambt = data.lambda{i, j, k};
                        relet = data.relerr{i, j, k};

                        % take the component and fit counts from the stored fits
                        % rather than repeating the constants used in 'tca'
                        nn = size(neurt{1}, 2);
                        nrp = numel(neurt);
                        if nsb <= 1 && nbt + offset > nrp
                            error('ca_analysis_functions:tca_bootstrap:tooFewFits', ...
                                'tca_bootstrap needs %d stored fits but only %d are available.', nbt + offset, nrp);
                        end

                        neur = cell(1, nbt);
                        temp = cell(1, nbt);
                        tria = cell(1, nbt);
                        lamb = cell(1, nbt);
                        rele = zeros(1, nbt);
                        for ii = 1: nbt
                            if nsb > 1
                                idt = randsample(1: nrp, nsb);
                            else
                                idt = ii + offset;
                            end
                            neur{ii} = mean(reshape(cell2mat(neurt(idt)), [], nn, nsb), 3);
                            temp{ii} = mean(reshape(cell2mat(tempt(idt)), [], nn, nsb), 3);
                            tria{ii} = mean(reshape(cell2mat(triat(idt)), [], nn, nsb), 3);
                            lamb{ii} = mean(reshape(cell2mat(lambt(idt)), nn, nsb), 2);
                            rele(ii) = mean(relet(idt));
                        end

                        datao.neuron{i, j, k} = neur;
                        datao.temporal{i, j, k} = temp;
                        datao.trial{i, j, k} = tria;
                        datao.lambda{i, j, k} = lamb;
                        datao.relerr{i, j, k} = rele;
                    end
                end
            end

            dataout = {datao};

            
        case 'tca_component_analysis'
            % Compare each TCA trial factor with signed trial-wise behavior labels.
            % datain{1} data: TCA struct; data.trial{c, w, m} is a cell of trial x
            %   component matrices (one per fit).
            % datain{2} fdata: struct with nc x nw x nm cells
            %   s  - per-trial onset (NaN = missing; presumably imaging frames),
            %   b  - per-trial whisk measure (> 0 -> whisking; NaN = missing),
            %   wp - event times compared after scaling by Fsb / Fsi (presumably
            %        behavior frames - confirm).
            % Only conditions c = 2..nc are analysed; sl(:, c) gives the window
            % (seconds after onset) in which a wp event marks the trial.
            % A component is "on" in a trial when its trial loading exceeds the std
            % of all loadings. Each of 8 labels [wh&~wp; ~wh&~wp; wh&wp; ~wh&wp; wh;
            % ~wh; wp; ~wp] is +/-1 per trial; the output is the fraction of trials
            % where label and component state agree.
            % Output: {scout}, (nc - 1) x nw x nm cell; scout{c - 1, w, m}{f} is an
            %   8 x component matrix.
            data = datain{1};
            fdata = datain{2};
            fdatas = fdata.s;
            fdatab = fdata.b;
            fdatawp = fdata.wp;
            [nc, nw, nm] = size(data.trial);
            sl = [0, 3, 0; 1, 5, 1];   % rows: window start / end (s), per condition
            sgn = @(b) 2 * b - 1;      % {0, 1} -> {-1, +1}

            scout = cell(nc - 1, nw, nm);
            for c = 2: nc
                for w = 1: nw
                    for m = 1: nm
                        onset = fdatas{c, w, m};
                        whisk = fdatab{c, w, m};
                        okon = ~isnan(onset);
                        okwh = ~isnan(whisk);
                        valid = okon & okwh;
                        whb = double(whisk(valid) > 0);

                        events = fdatawp{c, w, m};
                        starts = onset(valid);
                        wpb = zeros(1, length(starts));
                        for t = 1: length(starts)
                            lo = (starts(t) + sl(1, c) * Fsi) * Fsb / Fsi;
                            hi = (starts(t) + sl(2, c) * Fsi) * Fsb / Fsi;
                            inwin = (events > lo) & (events < hi);
                            wpb(t) = any(inwin(:));
                        end

                        labels = {sgn(whb .* (1 - wpb)), sgn((1 - whb) .* (1 - wpb)), ...
                                  sgn(whb .* wpb), sgn((1 - whb) .* wpb), ...
                                  sgn(whb), -sgn(whb), sgn(wpb), -sgn(wpb)};

                        usable = okwh(okon);   % trial-factor rows are the trials with valid onset
                        nfit = length(data.trial{c, w, m});
                        res = cell(1, nfit);
                        for f = 1: nfit
                            tf = data.trial{c, w, m}{f}(usable, :);
                            act = sgn(double(tf > std(tf(:))));
                            agree = cellfun(@(v) sum(act .* v' == 1, 1) ./ size(act, 1), ...
                                labels, 'uniformoutput', false);
                            res{f} = vertcat(agree{:});
                        end
                        scout{c - 1, w, m} = res;
                    end
                end
            end

            dataout = {scout};
            
        case 'tca_behav_summary'
            data = datain{1};
            scin = datain{2};
            if length(datain) < 3
                cuse = 4;
            else
                cuse = datain{3};
            end
            [nc, nw, nm] = size(data.trial);
            neur = data.neuron;
            temp = data.temporal;
            tria = data.trial;
            lamb = data.lambda;
            nrp = length(temp{1, 1, 1});
            [ntemp, nn] = size(temp{1, 1, 1}{1});
            tempm = cellfun(@(x) squeeze(mean(reshape(cell2mat(x')', nn, ntemp, []), 3)), temp, 'uniformoutput', false);
%             tempse = cellfun(@(x) squeeze(std(reshape(cell2mat(x')', nn, nt, []), [], 3)) / (2 * sqrt(nrp)), temp, 'uniformoutput', false);
            scm = cellfun(@(x) squeeze(mean(reshape(cell2mat(x')', nn, size(x{1}, 1), []), 3)), scin, 'uniformoutput', false);
%             scse = cellfun(@(x) squeeze(std(reshape(cell2mat(x')', nn, size(x{1}, 1), []), [], 3)) / (2 * sqrt(nrp)), scin, 'uniformoutput', false);
            triam = cellfun(@(x) squeeze(mean(reshape(cell2mat(x')', nn, size(x{1}, 1), []), 3)), tria, 'uniformoutput', false);
            neurm = cellfun(@(x) squeeze(mean(reshape(cell2mat(x')', nn, size(x{1}, 1), []), 3)), neur, 'uniformoutput', false);
            lambm = cellfun(@(x) squeeze(mean(reshape(cell2mat(x'), nn, length(x), []), 2)), lamb, 'uniformoutput', false);
            ncomp = size(scm{1, 1, 1}, 2);
            
            tempmuse = cell(nc - 1, nw, nm);
%             tempseuse = cell(nc - 1, nw, nm);
            scmuse = cell(nc - 1, nw, nm);
%             scseuse = cell(nc - 1, nw, nm);
            triamuse = cell(nc - 1, nw, nm);
            neurmuse = cell(nc - 1, nw, nm);
            lambmuse = cell(nc - 1, nw, nm);
            for i = 2: nc
                for j = 1: nw
                    for k = 1: nm
                        id1 = scm{i - 1, j, k};
                        tempmt = zeros(cuse, ntemp, ncomp);
%                         tempset = zeros(cuse, nt, ncomp);
                        scmt = zeros(cuse, ncomp);
%                         scset = zeros(cuse, ncomp);
                        triamt = cell(1, ncomp);
                        neurmt = cell(1, ncomp);
                        lambmt = cell(1, ncomp);
                        for ii = 1: ncomp
                            [~, id] = sort(id1(:, ii), 'descend');
                            tempmt(:, :, ii) = tempm{i, j, k}(id(1: cuse), :);
%                             tempset(:, :, ii) = tempse{i, j, k}(id(1: cuse), :);
                            scmt(:, ii) = scm{i - 1, j, k}(id(1: cuse), ii);
%                             scset(:, ii) = scse{i - 1, j, k}(id(1: cuse), ii);
                            triamt{ii} = triam{i, j, k}(id(1: cuse), :);
                            neurmt{ii} = neurm{i, j, k}(id(1: cuse), :);
                            lambmt{ii} = lambm{i, j, k}(id(1: cuse), :);
                        end
                        tempmuse{i - 1, j, k} = tempmt;
%                         tempseuse{i - 1, j, k} = tempset;
                        scmuse{i - 1, j, k} = scmt;
%                         scseuse{i - 1, j, k} = scset;
                        triamuse{i - 1, j, k} = triamt;
                        neurmuse{i - 1, j, k} = neurmt;
                        lambmuse{i - 1, j, k} = lambmt;
                    end
                end
            end
            
            dataall.tempm = tempm;
%             dataall.tempse = tempse;
            dataall.scm = scm;
%             dataall.scse = scse;
            
            dataout = {tempmuse, triamuse, neurmuse, lambmuse scmuse, dataall};
            
        case 'tca_weighted_principal_component'
            temp = datain{1};
            lamb = datain{2};
%             tria = datain{2};
%             neur = datain{3};
            [nc, nw, nm] = size(temp);
            
            tempn = temp;
            %%% compute scale score %%%
            for i = 1: nc
                for j = 1: nw
                    for k = 1: nm
                        tmpt = temp{i, j, k};
                        tmpl = cell2mat(lamb{i, j, k});
                        tmpl = reshape(tmpl, size(tmpl, 1), [], size(tmpl, 2));
% %                         tmptr = cell2mat(cellfun(@(x) sum(x, 2), tria{i, j, k}, 'uniformoutput', false));
% %                         tmpn = cell2mat(cellfun(@(x) sum(x, 2), neur{i, j, k}, 'uniformoutput', false));
%                         tmptr = cell2mat(cellfun(@(x) sum(x .* (x > std(x(:))), 2), tria{i, j, k}, 'uniformoutput', false));
%                         tmpn = cell2mat(cellfun(@(x) sum(x .* (x > std(x(:))), 2), neur{i, j, k}, 'uniformoutput', false));
%                         tmptr = cell2mat(cellfun(@(x) sqrt(sum(x .^ 2, 2)), tria{i, j, k}, 'uniformoutput', false));
%                         tmpn = cell2mat(cellfun(@(x) sqrt(sum(x .^ 2, 2)), neur{i, j, k}, 'uniformoutput', false));
% %                         scl = reshape(tmptr .* tmpn, size(tmpn, 1), 1, []);
                        scl = tmpl;
%                         scl = reshape(tmptr, size(tmpn, 1), 1, []);
                        temp{i, j, k} = tmpt .* scl;
                        tempn{i, j, k} = tmpt;
                    end
                end
            end
            
            %%% compute valid components %%%
            tp = cell(nc, nw);
            tpn = cell(nc, nw);
%             rg = [5.2, 2.6];
            rg = [7, 5];
            thres = 3;
            for i = 1: nc
                for j = 1: nw
                    tp{i, j} = [];
                    for k = 1: nm
                        tmp = temp{i, j, k};
                        tmpn = tempn{i, j, k};
%                         tp{i, j} = [tp{i, j}; mean(tmp, 1)];
                        tp{i, j} = [tp{i, j}; tmp];
                        tpn{i, j} = [tpn{i, j}; tmpn];
                    end
                    tmp = tp{i, j};
                    tmpn = tpn{i, j};
                    [cuse, nt, ncomp] = size(tmp);
                    tmp = squeeze(mat2cell(tmp, size(tmp, 1), size(tmp, 2), ones(1, size(tmp, 3))));
                    tmpn = squeeze(mat2cell(tmpn, size(tmpn, 1), size(tmpn, 2), ones(1, size(tmpn, 3))));
%                     tmp = cellfun(@(x) x{1}, tmp, 'uniformoutput', false);
                    for k = 1: ncomp
                        tmpp = tmp{k};
                        [mx, idt] = max(tmpp, [], 2);
%                         med = median(tmpp, 1);
%                         idt1 = (mx / max(med)) <= thres;
%                         idt1 = ~isoutlier(mx);
                        idt1 = true(size(idt));
                        idt = idt <= rg(i) * Fsi + 2 * Fsi;
%                         idt = true(size(idt));
                        tmpp = tmpp(idt & idt1, :);
                        tmp{k} = tmpp;
                        tmpn{k} = tmpn{k}(idt & idt1, :);
                    end
                    tp{i, j} = tmp;
                    tpn{i, j} = tmpn;
                end
            end

            dataout = {temp, tp, tpn};

        case 'tca_component_subcluster'
            temp = datain{1};
            [nc, nw] = size(temp);
            ncomp = length(temp{1});
            nt = size(temp{1}{1}, 2);
            nrp = 50;
            bds = {[0:3,5], 0: 0.4: 2.2};
%             bds = {[0, 2, 4, 5], 0: 0.4: 2.2};
            ctrst = linspace(0, nt, 101);
            ctrs{1} = (ctrst(1: end - 1) + ctrst(2: end)) / 2;
            ctrst = linspace(31, 61, 101);
            ctrs{2} = (ctrst(1: end - 1) + ctrst(2: end)) / 2;
            gk = 15;
            
            clusts = cell(nc, nw);
            dists = cell(nc, nw);
            for i = 1: nc
                for j = 1: nw
                    clusts{i, j} = cell(ncomp, 1);
                    dists{i, j} = [];
                    for k = 1: ncomp
                        tmp = temp{i, j}{k};
                        clusts{i, j}{k} = [];
%                         for ii = 1: nrp
%                             [dout, dsout] = subcluster(tmp);
%                             clusts{i, j}{k} = cat(3, clusts{i, j}{k}, dout);
%                             dists{i, j}(:, k, ii) = dsout;
%                         end

                        tmp = convn(tmp, (gausswin(gk) / sum(gausswin(gk)))', 'same');
                        for ii = 1: size(tmp, 1)
                            tmp(ii, :) = normalize(tmp(ii, :));
                        end
                        if i == 1
                            [~, id] = max(tmp, [], 2);
                        else
                            [~, id] = max(diff(tmp, 1, 2), [], 2);
                        end
%                         [~, id] = max(diff(tmp, 1, 2), [], 2);
%                         [~, id] = max(tmp, [], 2);
                        
                        idall = cell(1, length(bds{i}) - 1);
                        clusts{i, j}{k} = zeros(length(bds{i}) - 1, size(tmp, 2));
                        for ii = 1: length(bds{i}) - 1
                            idall{ii} = id >= (bds{i}(ii) + 2) * Fsi & id < (bds{i}(ii + 1) + 2) * Fsi;
                            if sum(idall{ii}) > 0
                                clusts{i, j}{k}(ii, :, 1) = mean(tmp(idall{ii}, :), 1) * sum(idall{ii}) / length(idall{ii});
                                clusts{i, j}{k}(ii, :, 2) = std(tmp(idall{ii}, :), 1) / 2 * sqrt(sum(idall{ii})) / length(idall{ii});
                            end
                        end
%                         id1 = id <= (bds(i, 1) + 2) * Fsi;
%                         id2 = id > (bds(i, 1) + 2) * Fsi & id <= (bds(i, 2) + 2) * Fsi;
%                         id3 = id > (bds(i, 2) + 2) * Fsi;
                        
%                         clusts{i, j}{k} = [mean(tmp(id1, :), 1); mean(tmp(id2, :), 1); mean(tmp(id3, :), 1)];
%                         clusts{i, j}{k} = [mean(tmp(id1, :), 1) * sum(id1); mean(tmp(id2, :), 1) * sum(id2); mean(tmp(id3, :), 1) * sum(id3)];
                        if i == 1
                            ttmp = ksdensity(id, ctrs{i}, 'bandwidth', 20);
%                             ttmp = ksdensity(id, ctrs{i});
                        else
                            ttmp = ksdensity(id, ctrs{i}, 'bandwidth', 4);
                        end
                        dists{i, j}(k, :) = ttmp / sum(ttmp);
%                         dists{i, j}(k, :) = ttmp / max(ttmp);
                    end
                end
            end
            
            dataout = {clusts, dists};
            
        case 'tca_weighted_principal_component_average'
            tp = datain{1};
            [nc, nw] = size(tp);
            
            tpmout = tp;
            tpseout = tp;
            kn = 20;
            for i = 1: nc
                for j = 1: nw
                    tmp = tp{i, j};
                    tpmout{i, j} = [];
                    tpseout{i, j} = [];
                    for k = 1: length(tmp)
                        tmpp = tmp{k};
                        tmpp = convn(tmpp, (gausswin(kn) / sum(gausswin(kn)))', 'same');
%                         tmpp = tmpp ./ sum(tmpp, 2);
                        tmppt = median(tmpp, 1);
                        tpmout{i, j}(k, :) = tmppt;
%                         tpmout{i, j}(k, :) = tmppt / max(tmppt);
%                         normt = norm(tmppt);
%                         normt = sum(tmppt);
%                         tpmout{i, j}(k, :) = tmppt / normt;
%                         tmppt = median(tmpp, 1);
%                         tpmout{i, j}(k, :) = conv(tmppt, gausswin(kn) / sum(gausswin(kn)), 'same');
%                         tpseout{i, j}(k, :) = std(tmpp, 1) / (2 * sqrt(1 * size(tmpp, 1)) * normt); % 20 repeats %
                        tpseout{i, j}(k, :) = std(tmpp, 1) / (2 * sqrt(1 * size(tmpp, 1))); % 20 repeats %
                    end
                end
            end
            
            dataout = {tpmout, tpseout};

            
            
            

    end
end

%% ------------------------------------------------- %%
%% auxiliary functions %%
function dataout = extract_factor(datain, fflag)
    %%% extract_factor: resample spike signals per behavior combination and
    %%% re-integrate them into fluorescence.
    %
    % Each behavior combination whose first index flag is 1 is handled in turn.
    %   - The amplitude distribution of its samples is compared with the same
    %     combination with the first flag cleared.
    %   - dist_equalize turns the excess into per-bin sample counts.
    %   - Those counts are drawn at random nr times.
    %   - The masked spikes are averaged over the draws and re-integrated
    %     with ar_integrate.
    %
    % Inputs
    %   datain : {spks, gs, bs, idxwp, idxwh, idxs}, each a cell array indexed
    %            (i, j, k). The columns of spks{i,j,k} line up with the
    %            behavior index vectors idxwp/idxwh/idxs{i,j,k};
    %            gs{i,j,k}(:, n) and bs{i,j,k}(n) are the coefficients and
    %            offset handed to ar_integrate for row n.
    %   fflag  : 'extract_wipe' | 'extract_whisk' | 'extract_stim' |
    %            'extract_pseudo'. Selects the index ordering in combine_idx.
    %            'extract_pseudo' additionally uses ksdensity estimates with
    %            rounded counts; the other modes use histograms with ceil'd
    %            counts.
    % Output
    %   dataout : {spksfr, datafr}, or {spksfr, datafr, indout} for
    %             'extract_pseudo'. indout{i,j,k} lists the first column of
    %             every connected nonzero run of the pseudo index.
    spks = datain{1};
    gs = datain{2};
    bs = datain{3};
    idxwp = datain{4};
    idxwh = datain{5};
    idxs = datain{6};

    ispseudo = strcmp(fflag, 'extract_pseudo');
    nr = 100;   % number of random draws averaged in step 4
    datafr = spks;
    spksfr = spks;
    indout = idxs;
    for i = 1: size(datafr, 1)
        for j = 1: size(datafr, 2)
            for k = 1: size(datafr, 3)
                stmp = spks{i, j, k};
                idtwp = idxwp{i, j, k}(:)';
                idtwh = idxwh{i, j, k}(:)';
                idts = idxs{i, j, k}(:)';

                %%% 1: parse behavior combination %%%
                idxall = combine_idx(idtwp, idtwh, idts, fflag);
                indout{i, j, k} = idxall;
                idpool = unique(idxall', 'rows');
                idwppool = idpool(idpool(:, 1) == 1, :);
                npool = size(idpool, 1);
                nid = size(idwppool, 1);

                %%% 2: get weighted signal distribution %%%
                mx = max(stmp(:));
                nn = size(stmp, 1);
                if ispseudo
                    % ksdensity is evaluated AT the nedge points
                    nedge = 1000;
                    nbin = nedge;
                else
                    % histcounts returns nedge - 1 bin counts
                    nedge = 1001;
                    nbin = nedge - 1;
                end
                edges = linspace(0, mx, nedge);
                dist = NaN(nid, nbin);
                distraw = NaN(npool, nbin);

                % normalized distribution and sample weight of every combination
                weight = zeros(1, npool);
                tpall = cell(npool, 1);
                for ii = 1: npool
                    l1 = ismember(idxall', idpool(ii, :), 'rows');
                    tp1 = stmp(:, l1(:));
                    tp1 = tp1(:);
                    tpall{ii} = tp1;
                    weight(ii) = sum(l1);
                    h1 = extract_factor_density(tp1, edges, mx, ispseudo);
                    distraw(ii, :) = h1 / sum(h1);
                end

                % Per-bin sample counts for each combination with first flag = 1,
                % matched against the same combination with that flag cleared.
                idl = zeros(nid, size(idxall, 2));
                idtodo = [];
                for ii = 1: nid
                    l1 = ismember(idxall', idwppool(ii, :), 'rows');
                    idtmp = idwppool(ii, :);
                    idtmp(1) = false;
                    l2 = ismember(idxall', idtmp, 'rows');
                    tp1 = stmp(:, l1(:));
                    tp2 = stmp(:, l2(:));
                    tp1 = tp1(:);
                    tp2 = tp2(:);
                    idl(ii, :) = l1';

                    if ~isempty(tp2) && ~isempty(tp1)
                        h1 = extract_factor_density(tp1, edges, mx, ispseudo);
                        h2 = extract_factor_density(tp2, edges, mx, ispseudo);
                        % NOTE(review): on the ksdensity path sum(h1) is a sum of
                        % density values rather than a sample count. Confirm this
                        % is the scale dist_equalize should receive.
                        hd = dist_equalize(h1 / sum(h1), h2 / sum(h2), sum(h1));
                        dist(ii, :) = extract_factor_count(hd, ispseudo);
                    else
                        % no counterpart samples: estimated from pooled
                        % distributions below
                        idtodo = [idtodo, ii];
                    end
                end

                %%% estimate averaged existing distribution %%%
                for ii = idtodo
                    % weights x solving idpool' * x = b (least squares when the
                    % system is not square), where b is the flag-cleared pattern
                    b = idwppool(ii, :)';
                    b(1) = false;
                    x = idpool' \ b;
                    if all(x == 0)
                        x = ones(size(x)) / numel(x);
                    end
                    distuse = (distraw' * (x(:) .* weight(:)))';
                    distuse = distuse / sum(distuse);
                    % Use the same estimator as distraw so the bins line up.
                    % A histogram here would have one bin fewer than distuse
                    % on the pseudo path.
                    isself = ismember(idpool, idwppool(ii, :), 'rows');
                    h1 = extract_factor_density(tpall{isself}, edges, mx, ispseudo);
                    hd = dist_equalize(h1 / sum(h1), distuse, sum(h1));
                    dist(ii, :) = extract_factor_count(hd, ispseudo);
                end

                %%% 3: random pick signal %%%
                % nr independent draws over nedge - 1 intervals (lo, hi].
                % Zero-valued samples are never drawn, which leaves the masked
                % signal unchanged. On the pseudo path the last column of dist
                % is not used.
                idd = cell(nr, 1);
                for ii = 1: nid
                    ll1 = repmat(idl(ii, :), nn, 1);
                    for jj = 1: nedge - 1
                        idt1 = find(ll1 > 0 & stmp > edges(jj) & stmp <= edges(jj + 1));
                        for kk = 1: nr
                            idd{kk} = [idd{kk}; idt1(randsample(length(idt1), min(length(idt1), dist(ii, jj))))];
                        end
                    end
                end

                %%% 4: get updated spk info %%%
                dtt = zeros([size(stmp), nr]);
                for kk = 1: nr
                    keep = false(size(stmp));
                    keep(idd{kk}) = true;
                    dtt(:, :, kk) = stmp .* keep;
                end
                t = nanmean(dtt, 3);
                spksfr{i, j, k} = t;

                %%% 5: get updated fluorescence %%%
                for ii = 1: nn
                    datafr{i, j, k}(ii, :) = ar_integrate(t(ii, :), gs{i, j, k}(:, ii)) + bs{i, j, k}(ii);
                end
                disp([num2str(i), ' ', num2str(j), ' ', num2str(k)])
            end
        end
    end

    if ispseudo
        for i = 1: size(datafr, 1)
            for j = 1: size(datafr, 2)
                for k = 1: size(datafr, 3)
                    % bwlabeln treats every nonzero entry (1 and 2 alike) as foreground
                    [lab, nlab] = bwlabeln(indout{i, j, k});
                    starts = zeros(1, nlab);
                    for kk = 1: nlab
                        starts(kk) = find(lab == kk, 1);
                    end
                    indout{i, j, k} = starts;
                end
            end
        end
        dataout = {spksfr, datafr, indout};
    else
        dataout = {spksfr, datafr};
    end
end

function h = extract_factor_density(tp, edges, mx, ispseudo)
    %%% Unnormalized distribution of the samples tp over edges: ksdensity at
    %%% the edge points (bandwidth two grid steps) on the pseudo path,
    %%% histogram counts otherwise.
    if ispseudo
        h = ksdensity(tp, edges, 'bandwidth', 2 * mx / (length(edges) - 1));
    else
        h = histcounts(tp, edges);
    end
end

function n = extract_factor_count(hd, ispseudo)
    %%% Convert fractional per-bin amounts from dist_equalize to integer
    %%% sample counts: round on the pseudo path, ceil otherwise.
    if ispseudo
        n = round(hd);
    else
        n = ceil(hd);
    end
end

function idxall = combine_idx(idtwp, idtwh, idts, fflag)
    %%% combine_idx: stack the behavior index row vectors in the order the
    %%% extraction mode expects. The first row is the one the caller
    %%% conditions on.
    %
    % idtwp, idtwh, idts : row vectors of equal length.
    % fflag              : 'extract_wipe' | 'extract_whisk' | 'extract_stim'
    %                      | 'extract_pseudo'.
    % idxall             : 3 x T matrix for the first three modes. For
    %                      'extract_pseudo' it is a 1 x T double row:
    %                        - 2 where any index is set;
    %                        - among the remaining columns, a random
    %                          round(n/2) of them are 0 and the rest 1.
    % Raises an error with identifier
    % 'ca_analysis_functions:combine_idx:unknownFlag' for any other flag.
    switch fflag
        case 'extract_wipe'
            idxall = [idtwp; idtwh; idts];
        case 'extract_whisk'
            idxall = [idtwh; idts; idtwp];
        case 'extract_stim'
            idxall = [idts; idtwp; idtwh];
        case 'extract_pseudo'
            iduse = ~idts & ~idtwh & ~idtwp;
            tmp = find(iduse);
            % randomly move half of the behavior-free columns to label 0
            iduse(tmp(randsample(length(tmp), round(length(tmp) / 2)))) = false;
            iduse = double(iduse);
            iduse(idts | idtwh | idtwp) = 2;
            idxall = iduse;
        otherwise
            % without this the caller only sees "output argument not assigned"
            error('ca_analysis_functions:combine_idx:unknownFlag', ...
                'Unknown extraction flag "%s".', fflag);
    end
end

function hd = dist_equalize(h1n, h2n, n1)
    %%% dist_equalize: iteratively shave mass off the bins where the
    %%% normalized distribution h1n exceeds h2n, accumulating how much was
    %%% removed from each bin.
    %
    % h1n, h2n : normalized distributions over the same bins (same size).
    % n1       : total behind h1n. Sets the convergence threshold 1/n1 and the
    %            starting scale of the removed amount.
    % hd       : per-bin amount removed. It is not rounded; callers round or
    %            ceil it.
    %
    % Each step removes 10% (rt) of the positive excess h1n - h2n, scaled by
    % the remaining total nn, then renormalizes h1n. The loop stops when:
    %   - every bin differs by less than 1/n1;
    %   - less than half of n1 remains;
    %   - the set of bins where h1n > h2n departs from the initial set; or
    %   - no bin of h1n exceeds h2n any more.
    % The last condition covers unequal totals and NaN inputs, where nothing
    % can be removed and the loop would otherwise never end.
    nn = n1;
    thres = 1 / n1;
    rt = 0.1;
    hd = zeros(size(h1n));
    idreal = h1n > h2n;
    while true
        d = abs(h1n - h2n);
        idcurr = h1n > h2n;
        ratio = sum(idreal & idcurr) / sum(idreal | idcurr);
        if all(d < thres) || nn < n1 / 2 || ratio < 0.9999 || ~any(idcurr)
            break
        end
        dt = rt * max(0, h1n - h2n);
        hd = hd + dt * nn;
        nn = nn * (1 - sum(dt));
        h1n = h1n - dt;
        h1n = h1n / sum(h1n);
    end
end

function [datout, idout] = subcluster_recon(sig, trial, neuron, iduse)
    %%% subcluster_recon: k-means cluster the selected columns of sig and fit
    %%% ncp(tensor(rec), 1) to the reconstruction of each cluster.
    %
    % sig, trial, neuron : factor matrices sharing a column index; iduse
    %                      selects columns, and the selected columns of sig
    %                      are clustered (CalinskiHarabasz over k = 1..4).
    % iduse              : logical column selector.
    % datout             : 1 x nc cell holding the ncp result per cluster
    %                      (presumably a rank-1 nonnegative CP fit from the
    %                      Tensor Toolbox - confirm).
    % idout              : 1 x nc cell of indices, within the selected
    %                      columns, belonging to each cluster.
    % Both outputs are [] when iduse selects nothing.
    if sum(iduse) == 0
        datout = [];
        idout = [];
        return
    end

    sig1 = sig(:, iduse);
    tri1 = trial(:, iduse);
    neu1 = neuron(:, iduse);
    eva = evalclusters(sig1', 'kmeans', 'CalinskiHarabasz', 'KList', 1: 4);
    nc = eva.OptimalK;
    if isnan(nc)
        % No usable optimum: put every selected column in one cluster. A
        % scalar label would make idt == 1 keep only the first column.
        nc = 1;
        idt = ones(size(sig1, 2), 1);
    else
        idt = kmeans(sig1', nc);
    end

    idout = cell(1, nc);
    datout = cell(1, nc);
    for i = 1: nc
        inclu = idt == i;
        tmp.U = {neu1(:, inclu), sig1(:, inclu), tri1(:, inclu)};
        recout = tca_analysis({tmp}, 'compute_reconst');
        datout{i} = ncp(tensor(recout{1}), 1);
        idout{i} = find(inclu);
    end
end

function [datout, dists] = subcluster(sig, nc)
    %%% subcluster: z-score each row of sig, k-means cluster the rows, and
    %%% return the smoothed mean of each cluster ordered by peak position.
    %
    % sig    : matrix whose rows are clustered.
    % nc     : number of clusters (optional, default 3). NaN puts every row
    %          into a single cluster.
    % datout : nc x size(sig, 2) cluster means passed through
    %          smooth(..., 15), sorted by the column of their maximum.
    % dists  : nc x 1 number of rows in each cluster, in the same order.
    if nargin < 2
        nc = 3;
    end
    for r = 1: size(sig, 1)
        sig(r, :) = normalize(sig(r, :));
    end
    if isnan(nc)
        % One label per row; a scalar label would select only the first row.
        nc = 1;
        idt = ones(size(sig, 1), 1);
    else
        idt = kmeans(sig, nc);
    end

    datout = zeros(nc, size(sig, 2));
    dists = zeros(nc, 1);
    for c = 1: nc
        member = idt == c;
        datout(c, :) = smooth(mean(sig(member, :), 1), 15);
        dists(c) = sum(member);
    end

    [~, pk] = max(datout, [], 2);
    [~, order] = sort(pk);
    datout = datout(order, :);
    dists = dists(order);
end

function [scdd, iddd] = crossent(tar, sig)
    %%% crossent: score every column of sig with crossentropy against tar and
    %%% return the scores in ascending order together with the column order.
    %
    % tar  : target, passed as the second argument of crossentropy.
    % sig  : matrix; each column is scored separately.
    % scdd : 1 x size(sig, 2) sorted scores.
    % iddd : column indices of sig in that sorted order.
    ncol = size(sig, 2);
    scores = zeros(1, ncol);
    col = 1;
    while col <= ncol
        scores(col) = crossentropy(sig(:, col), tar);
        col = col + 1;
    end
    [scdd, iddd] = sort(scores);
end












%% old functions %%
function dataout = remove_factor_old(datain, Fsi, fflag)
    %%% remove_factor_old (legacy): remove excess samples per behavior
    %%% combination and re-integrate the result into fluorescence.
    %
    % For each behavior combination whose first index flag is 1, the excess
    % samples (relative to the same combination with that flag cleared) are
    % randomly dropped. The masked spikes are averaged over nr draws and
    % re-integrated with ar_integrate.
    %
    % datain : {spks, gs, bs, idxwp, idxwh, idxs}, cell arrays indexed
    %          (i, j, k); see extract_factor for the layout.
    % Fsi    : sets the histogram resolution (8 * Fsi bins).
    % fflag  : passed to combine_idx to choose the index ordering.
    % dataout: {spksfr, datafr}.
    spks = datain{1};
    gs = datain{2};
    bs = datain{3};
    idxwp = datain{4};
    idxwh = datain{5};
    idxs = datain{6};

    nr = 100;
    datafr = spks;
    spksfr = spks;
    for i = 1: size(datafr, 1)
        for j = 1: size(datafr, 2)
            for k = 1: size(datafr, 3)
                stmp = spks{i, j, k};
                idtwp = idxwp{i, j, k}(:)';
                idtwh = idxwh{i, j, k}(:)';
                idts = idxs{i, j, k}(:)';

                %%% first parse behavior combination %%%
                idxall = combine_idx(idtwp, idtwh, idts, fflag);
                idpool = unique(idxall', 'rows');
                idwppool = idpool(idpool(:, 1) == 1, :);
                idtt = cell(nr, 1);
                for jj = 1: nr
                    idtt{jj} = true(size(stmp));
                end
                for ii = 1: size(idwppool, 1)
                    l1 = ismember(idxall', idwppool(ii, :), 'rows');
                    idtmp = idwppool(ii, :);
                    idtmp(1) = false;
                    l2 = ismember(idxall', idtmp, 'rows');
                    l1 = l1(:);
                    l2 = l2(:);
                    tp1 = stmp(:, l1);
                    tp2 = stmp(:, l2);
                    tp1 = tp1(:);
                    tp2 = tp2(:);

                    if ~isempty(tp2) && ~isempty(tp1)
                        mx = max([tp1; tp2]);
                        edges = linspace(0, mx, 8 * Fsi + 1);
                        h1 = histcounts(tp1, edges);
                        h2 = histcounts(tp2, edges);
                        h1n = h1 / sum(h1);
                        h2n = h2 / sum(h2);
                        % dist_equalize returns fractional amounts; randsample
                        % needs an integer sample count
                        hd = ceil(dist_equalize(h1n, h2n, sum(h1)));

                        idd = cell(nr, length(hd));
                        ll1 = repmat(l1', size(stmp, 1), 1);
                        for jj = 1: size(idd, 2)
                            edget = edges(jj: jj + 1);
                            idt1 = find(ll1 > 0 & stmp > edget(1) & stmp <= edget(2));
                            for kk = 1: nr
                                idd{kk, jj} = idt1(randsample(length(idt1), min(length(idt1), hd(jj))));
                            end
                        end

                        % drop the drawn samples from each repeat's mask
                        for jj = 1: nr
                            for kk = 1: size(idd, 2)
                                idtt{jj}(idd{jj, kk}) = false;
                            end
                        end
                    end
                end

                %%% get updated spk info %%%
                dtt = zeros([size(stmp), nr]);
                for ii = 1: nr
                    tmp = stmp .* idtt{ii};
                    dtt(:, :, ii) = tmp;
                end
                t = nanmean(dtt, 3);
                spksfr{i, j, k} = t;

                %%% get updated fluorescence %%%
                for ii = 1: size(stmp, 1)
                    datafr{i, j, k}(ii, :) = ar_integrate(t(ii, :), gs{i, j, k}(:, ii)) + bs{i, j, k}(ii);
                end
            end
        end
    end

    dataout = {spksfr, datafr};
end

function dataout = extract_factor_old(datain, fflag)
    %%% extract_factor_old (legacy): draw samples from one pooled target
    %%% distribution and re-integrate them into fluorescence.
    %
    % The per-combination counts of all behavior combinations whose first
    % index flag is 1 are pooled into one weighted target count per amplitude
    % bin. That many samples are drawn from the columns where the first index
    % is nonzero, averaged over nr draws, and re-integrated with ar_integrate.
    %
    % datain : {spks, gs, bs, idxwp, idxwh, idxs}, cell arrays indexed
    %          (i, j, k); see extract_factor for the layout.
    % fflag  : passed to combine_idx to choose the index ordering.
    % dataout: {spksfr, datafr}.
    spks = datain{1};
    gs = datain{2};
    bs = datain{3};
    idxwp = datain{4};
    idxwh = datain{5};
    idxs = datain{6};

    nr = 100;
    datafr = spks;
    spksfr = spks;
    for i = 1: size(datafr, 1)
        for j = 1: size(datafr, 2)
            for k = 1: size(datafr, 3)
                stmp = spks{i, j, k};
                idtwp = idxwp{i, j, k}(:)';
                idtwh = idxwh{i, j, k}(:)';
                idts = idxs{i, j, k}(:)';

                %%% 1: parse behavior combination %%%
                idxall = combine_idx(idtwp, idtwh, idts, fflag);
                idpool = unique(idxall', 'rows');
                idwppool = idpool(idpool(:, 1) == 1, :);
                idcurr = idxall(1, :);

                %%% 2: get weighted signal distribution %%%
                mx = max(stmp(:));
                nn = size(stmp, 1);
                nedge = 101;
                edges = linspace(0, mx, nedge);
                nid = size(idwppool, 1);
                wts = zeros(1, nid);
                dist = NaN(nid, nedge - 1);

                for ii = 1: nid
                    l1 = ismember(idxall', idwppool(ii, :), 'rows');
                    idtmp = idwppool(ii, :);
                    idtmp(1) = false;
                    l2 = ismember(idxall', idtmp, 'rows');
                    l1 = l1(:);
                    l2 = l2(:);
                    tp1 = stmp(:, l1);
                    tp2 = stmp(:, l2);
                    tp1 = tp1(:);
                    tp2 = tp2(:);
                    wts(ii) = length(tp1);

                    if ~isempty(tp2) && ~isempty(tp1)
                        h1 = histcounts(tp1, edges);
                        h2 = histcounts(tp2, edges);
                        h1n = h1 / sum(h1);
                        h2n = h2 / sum(h2);
                        dist(ii, :) = dist_equalize(h1n, h2n, sum(h1));
                    end
                end

                % Combine the per-combination amounts weighted by sample count.
                % randsample needs an integer count, so round up as the
                % histogram path of extract_factor does.
                % NOTE(review): a row of dist left NaN (empty tp1 or tp2) makes
                % all of distcurr NaN. min(n, NaN) then selects every sample
                % in each bin. Confirm this is intended.
                distcurr = ceil((wts / sum(wts)) * dist);

                %%% 3: random pick signal %%%
                idtt = cell(nr, 1);
                for jj = 1: nr
                    idtt{jj} = false(size(stmp));
                end
                idd = cell(nr, nedge);
                ll1 = repmat(idcurr, nn, 1);
                for jj = 1: nedge - 1
                    edget = edges(jj: jj + 1);
                    idt1 = find(ll1 > 0 & stmp >= edget(1) & stmp <= edget(2));
                    for kk = 1: nr
                        idd{kk, jj} = idt1(randsample(length(idt1), min(length(idt1), distcurr(jj))));
                    end
                end

                for jj = 1: nr
                    for kk = 1: size(idd, 2)
                        idtt{jj}(idd{jj, kk}) = true;
                    end
                end

                %%% 4: get updated spk info %%%
                dtt = zeros([size(stmp), nr]);
                for ii = 1: nr
                    tmp = stmp .* idtt{ii};
                    dtt(:, :, ii) = tmp;
                end
                t = nanmean(dtt, 3);
                spksfr{i, j, k} = t;

                %%% 5: get updated fluorescence %%%
                for ii = 1: nn
                    datafr{i, j, k}(ii, :) = ar_integrate(t(ii, :), gs{i, j, k}(:, ii)) + bs{i, j, k}(ii);
                end
            end
        end
    end

    dataout = {spksfr, datafr};
end

function hd = dist_equalize_old(h1n, h2n, n)
    %%% dist_equalize_old (legacy): iteratively shave mass off the bins where
    %%% h1n exceeds h2n and accumulate the removed amount per bin.
    %
    % h1n, h2n : normalized distributions over the same bins.
    % n        : starting total used to scale the removed amount.
    % hd       : per-bin amount removed (not rounded).
    %
    % The loop stops when:
    %   - every bin differs by less than 1e-4;
    %   - the remaining total drops below 1; or
    %   - no bin of h1n exceeds h2n.
    % The last condition covers unequal totals and NaN inputs, where nothing
    % can be removed and the loop would otherwise never end.
    nn = n;
    thres = 1e-4;
    rt = 0.1;
    hd = zeros(size(h1n));
    while true
        d = abs(h1n - h2n);
        if all(d < thres) || nn < 1 || ~any(h1n > h2n)
            break
        end
        dt = rt * max(0, h1n - h2n);
        hd = hd + dt * nn;
        nn = nn * (1 - sum(dt));
        h1n = h1n - dt;
        h1n = h1n / sum(h1n);
    end
end

function cv = min_cover(mtx, ref)
    %%% min_cover: greedily pick columns of mtx, in descending order of their
    %%% overlap with ref, until their union restricted to ref equals ref.
    %
    % mtx : matrix with one row per element of ref.
    % ref : 0/1 vector (row or column) to be covered.
    % cv  : column indices picked, in the order they were added. If ref is
    %       never matched, this is all columns.
    ref = ref(:);
    dt = mtx .* ref;
    scr = sum(dt, 1) / sum(ref);
    [~, order] = sort(scr, 'descend');
    cv = [];
    curr = false(size(ref));
    for c = order
        curr = curr | dt(:, c);
        cv = [cv, c];
        % Compare with the column form of ref. A row ref would broadcast
        % against the column curr into a matrix, so this test never passed.
        if all(curr == ref)
            break
        end
    end
end
